import torch
import numpy as np
import torch.nn as nn
from model.Physics_Attention import Physics_Attention_Structured_Mesh_2D
from radon_transform import RadonOperator

# Registry mapping an activation name to a zero-argument factory that returns
# a fresh activation module. Consumers build the activation with
# ``ACTIVATION_FUNCTIONS[name]()``, so every value must be callable with no
# arguments and must return an ``nn.Module``.
ACTIVATION_FUNCTIONS = {
    'gelu': nn.GELU,
    'tanh': nn.Tanh,
    'sigmoid': nn.Sigmoid,
    'relu': nn.ReLU,
    # LeakyReLU needs a non-default slope. Store a factory rather than a
    # ready-made instance: calling an instance with no arguments would run its
    # forward() without an input and raise a TypeError.
    'leaky_relu': lambda: nn.LeakyReLU(0.1),
    'softplus': nn.Softplus,
    'elu': nn.ELU,
    'silu': nn.SiLU,
}

class MultiLayerPerceptron(nn.Module):
    """Feed-forward network: input projection, hidden stack, linear head.

    The network is ``Linear(input_size, hidden_size)`` followed by the chosen
    activation, then ``num_layers`` blocks of ``Linear(hidden_size,
    hidden_size)`` + activation, then ``Linear(hidden_size, output_size)``
    with no activation on the output.

    Args:
        input_size: Size of the last dimension of the input.
        hidden_size: Width of the hidden representation.
        output_size: Size of the last dimension of the output.
        num_layers: Number of hidden blocks (``0`` means none).
        activation_type: Key into ``ACTIVATION_FUNCTIONS``.
        enable_residual: If true, each hidden block's output is added to its
            own input.

    Raises:
        NotImplementedError: If ``activation_type`` is not a registered key.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1,
                 activation_type='gelu', enable_residual=True):
        super().__init__()

        make_activation = ACTIVATION_FUNCTIONS.get(activation_type)
        if make_activation is None:
            raise NotImplementedError(f"Unsupported activation: {activation_type}")

        # Keep the configuration around for introspection.
        self.input_size, self.hidden_size = input_size, hidden_size
        self.output_size, self.num_layers = output_size, num_layers
        self.enable_residual = enable_residual

        # Submodules are created in the order input -> output -> hidden so
        # that parameter initialisation consumes random numbers in a fixed
        # sequence.
        input_projection = nn.Linear(input_size, hidden_size)
        self.input_layer = nn.Sequential(input_projection, make_activation())
        self.output_layer = nn.Linear(hidden_size, output_size)

        hidden_stack = nn.ModuleList()
        for _ in range(num_layers):
            hidden_stack.append(
                nn.Sequential(nn.Linear(hidden_size, hidden_size), make_activation())
            )
        self.hidden_layers = hidden_stack

    def forward(self, input_data):
        """Apply the network to ``input_data`` and return the head's output."""
        state = self.input_layer(input_data)

        for block in self.hidden_layers:
            transformed = block(state)
            # Optional skip connection around each hidden block.
            state = transformed + state if self.enable_residual else transformed

        return self.output_layer(state)

class PhysicsAttentionBlock(nn.Module):
    """One attention + feed-forward block with residual connections.

    Each sub-layer (physics-attention, then an MLP) is applied to a
    layer-normalised copy of its input and the result is added back to that
    input. When ``is_output_layer`` is true, a further LayerNorm and a linear
    projection to ``output_dim`` are applied and returned instead.

    Args:
        num_heads: Number of attention heads; the per-head width is
            ``embedding_dim // num_heads``.
        embedding_dim: Width of the token embeddings.
        dropout_rate: Dropout value forwarded to the attention module.
        activation_type: Activation name forwarded to the MLP.
        mlp_scaling_factor: The MLP hidden width is
            ``embedding_dim * mlp_scaling_factor``.
        is_output_layer: Whether this block ends with the output projection.
        output_dim: Width of the output projection (used only when
            ``is_output_layer`` is true).
        slice_count: Forwarded to the attention module as ``slice_num``.
        grid_height: Forwarded to the attention module as ``H``.
        grid_width: Forwarded to the attention module as ``W``.
    """

    def __init__(
        self,
        num_heads: int,
        embedding_dim: int,
        dropout_rate: float,
        activation_type='gelu',
        mlp_scaling_factor=4,
        is_output_layer=False,
        output_dim=1,
        slice_count=128,
        grid_height=85,
        grid_width=85
    ):
        super().__init__()
        self.is_output_layer = is_output_layer

        per_head_dim = embedding_dim // num_heads
        mlp_hidden_dim = embedding_dim * mlp_scaling_factor

        # Attention sub-layer and its normalisation.
        self.norm_layer_1 = nn.LayerNorm(embedding_dim)
        self.attention = Physics_Attention_Structured_Mesh_2D(
            embedding_dim,
            heads=num_heads,
            dim_head=per_head_dim,
            dropout=dropout_rate,
            slice_num=slice_count,
            H=grid_height,
            W=grid_width,
        )

        # Feed-forward sub-layer: a single hidden layer, no internal residual.
        self.norm_layer_2 = nn.LayerNorm(embedding_dim)
        self.feed_forward = MultiLayerPerceptron(
            embedding_dim,
            mlp_hidden_dim,
            embedding_dim,
            num_layers=0,
            enable_residual=False,
            activation_type=activation_type,
        )

        # The projection head only exists on the output block.
        if is_output_layer:
            self.norm_layer_3 = nn.LayerNorm(embedding_dim)
            self.output_projection = nn.Linear(embedding_dim, output_dim)

    def forward(self, input_embedding):
        """Run the block; return embeddings, or projected outputs on the last block."""
        attended = self.attention(self.norm_layer_1(input_embedding))
        after_attention = attended + input_embedding

        after_mlp = self.feed_forward(self.norm_layer_2(after_attention)) + after_attention

        if not self.is_output_layer:
            return after_mlp
        return self.output_projection(self.norm_layer_3(after_mlp))

class RadonBlock(nn.Module):
    """Residual block: Radon transform, learned re-weighting, filtered backprojection.

    The input field is projected with ``RadonOperator.forward``; a small
    convolutional network estimates weights in ``(0, 1)`` from the
    channel-averaged sinogram; the weighted sinogram is mapped back with
    ``RadonOperator.filter_backprojection`` and added to the input.

    Args:
        num_projection_angles: Number of angles, evenly spaced over
            ``[0, pi]`` (both endpoints included).

    NOTE(review): the operator is created with ``device="cuda"``, so this
    block presumably requires a CUDA device -- confirm against RadonOperator.
    """

    def __init__(self, num_projection_angles):
        super().__init__()
        self.num_projection_angles = num_projection_angles

        projection_angles = np.linspace(0, np.pi, num_projection_angles)
        self.radon_operator = RadonOperator(
            thetas=projection_angles,
            circle=False,
            device="cuda",
            filter_name="ramp",
        )

        # 1 -> 16 -> 1 channel conv net; the final sigmoid bounds the weights.
        estimator_layers = [
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(16, 1, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ]
        self.weight_estimator = nn.Sequential(*estimator_layers)

    def forward(self, input_field):
        """Return ``input_field`` plus its re-weighted Radon reconstruction.

        ``input_field`` must be 4-D; its last two dimensions are passed to the
        backprojection as ``original_shape``.
        """
        _, _, dim_2, dim_3 = input_field.shape

        sinogram = self.radon_operator.forward(input_field)

        # Average over channels, estimate weights, then average them over
        # dim 2 so the same weight is broadcast along that axis of the
        # sinogram.
        channel_mean = sinogram.mean(dim=1, keepdim=True)
        weights = self.weight_estimator(channel_mean).mean(dim=2, keepdim=True)

        reconstructed_field = self.radon_operator.filter_backprojection(
            sinogram * weights,
            original_shape=(dim_2, dim_3),
        )
        return input_field + reconstructed_field

class RNO(nn.Module):
    """RNO model: MLP encoder, physics-attention blocks, then Radon blocks.

    The forward pass encodes per-point inputs with an MLP, runs them through
    ``num_attention_layers`` physics-attention blocks (the last one projects
    to ``output_dim``), reshapes the result onto the
    ``grid_height x grid_width`` grid, applies ``num_radon_blocks`` Radon
    blocks, and adds the pre-Radon features back as a skip connection.

    Args:
        num_projection_angles: Number of projection angles per Radon block.
        num_radon_blocks: Number of Radon blocks applied after attention.
        spatial_dim: Width of the raw coordinate input (used when
            ``use_unified_encoding`` is false).
        num_attention_layers: Number of physics-attention blocks.
        embedding_dim: Width of the token embeddings.
        dropout_rate: Dropout value forwarded to the attention blocks.
        num_attention_heads: Number of attention heads per block.
        enable_time_input: Whether to create the time-embedding encoder.
        activation_type: Activation name used by the MLPs.
        mlp_scaling_factor: Hidden-width multiplier of the block MLPs.
        input_function_dim: Width of the per-point function input.
        output_dim: Width of the model output.
        slice_count: Forwarded to the attention blocks.
        reference_grid_size: Side length of the reference grid used by the
            unified position encoding.
        use_unified_encoding: If true, replace raw coordinates by distances
            to a ``reference_grid_size x reference_grid_size`` reference grid.
        grid_height: Number of grid points along the first grid axis.
        grid_width: Number of grid points along the second grid axis.
    """

    def __init__(self,
                 num_projection_angles=32,
                 num_radon_blocks=1,
                 spatial_dim=1,
                 num_attention_layers=6,
                 embedding_dim=256,
                 dropout_rate=0.0,
                 num_attention_heads=8,
                 enable_time_input=False,
                 activation_type='gelu',
                 mlp_scaling_factor=1,
                 input_function_dim=1,
                 output_dim=1,
                 slice_count=32,
                 reference_grid_size=8,
                 use_unified_encoding=False,
                 grid_height=85,
                 grid_width=85):
        super(RNO, self).__init__()
        self.model_name = 'RNO'
        self.grid_height = grid_height
        self.grid_width = grid_width
        self.reference_grid_size = reference_grid_size
        self.num_projection_angles = num_projection_angles
        self.num_radon_blocks = num_radon_blocks
        self.output_dim = output_dim
        self.use_unified_encoding = use_unified_encoding
        self.enable_time_input = enable_time_input
        self.embedding_dim = embedding_dim
        self.spatial_dim = spatial_dim

        # The encoder input is the function values concatenated with either
        # the distance-to-reference encoding or the raw coordinates.
        if use_unified_encoding:
            input_dim = input_function_dim + reference_grid_size * reference_grid_size
        else:
            input_dim = input_function_dim + spatial_dim

        self.input_encoder = MultiLayerPerceptron(
            input_dim,
            embedding_dim * 2,
            embedding_dim,
            num_layers=0,
            enable_residual=False,
            activation_type=activation_type
        )

        if enable_time_input:
            self.time_encoder = nn.Sequential(
                nn.Linear(embedding_dim, embedding_dim),
                nn.SiLU(),
                nn.Linear(embedding_dim, embedding_dim)
            )

        # PhysicsAttentionBlock names its head-count parameter ``num_heads``.
        self.attention_blocks = nn.ModuleList([
            PhysicsAttentionBlock(
                num_heads=num_attention_heads,
                embedding_dim=embedding_dim,
                dropout_rate=dropout_rate,
                activation_type=activation_type,
                mlp_scaling_factor=mlp_scaling_factor,
                output_dim=output_dim,
                slice_count=slice_count,
                grid_height=grid_height,
                grid_width=grid_width,
                is_output_layer=(i == num_attention_layers - 1)
            ) for i in range(num_attention_layers)
        ])

        # Custom initialisation covers the modules created so far; the latent
        # embedding and the Radon blocks below keep their own initialisation.
        self.initialize_parameters()
        self.latent_embedding = nn.Parameter(
            (1 / embedding_dim) * torch.rand(embedding_dim, dtype=torch.float)
        )

        self.radon_transforms = nn.Sequential(*[
            RadonBlock(self.num_projection_angles)
            for _ in range(num_radon_blocks)
        ])

    def initialize_parameters(self):
        """Apply ``_init_parameters`` to every submodule registered so far."""
        self.apply(self._init_parameters)

    def _init_parameters(self, module):
        """Initialise one module: truncated-normal Linear weights, unit norms.

        Linear weights are drawn from a truncated normal with ``std=0.02`` and
        biases are zeroed; LayerNorm/BatchNorm1d get weight 1 and bias 0.
        Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    def generate_position_grid(self, batch_size=1, device=None):
        """Return distances from every grid point to every reference point.

        Both grids span ``[0, 1]`` on each axis. The result does not depend on
        the batch index, so it is computed once and tiled along the batch
        dimension.

        Args:
            batch_size: Number of copies along the leading dimension.
            device: Device on which to build the tensors. Defaults to the
                device of the model's ``latent_embedding`` parameter.

        Returns:
            A contiguous float tensor of shape ``(batch_size, grid_height,
            grid_width, reference_grid_size ** 2)`` holding Euclidean
            distances.
        """
        if device is None:
            device = self.latent_embedding.device

        def unit_axis(num_points):
            # Evenly spaced coordinates in [0, 1], endpoints included.
            return torch.tensor(
                np.linspace(0, 1, num_points), dtype=torch.float, device=device
            )

        def coordinate_grid(size_x, size_y):
            # grid[i, j] = (x_i, y_j), shape (size_x, size_y, 2).
            axis_x = unit_axis(size_x).reshape(size_x, 1, 1).expand(size_x, size_y, 1)
            axis_y = unit_axis(size_y).reshape(1, size_y, 1).expand(size_x, size_y, 1)
            return torch.cat((axis_x, axis_y), dim=-1)

        x_coords, y_coords = self.grid_height, self.grid_width
        ref_size = self.reference_grid_size

        full_grid = coordinate_grid(x_coords, y_coords)
        ref_grid = coordinate_grid(ref_size, ref_size)

        # Pairwise differences broadcast to (x, y, ref, ref, 2), then reduced
        # over the coordinate axis.
        offsets = full_grid[:, :, None, None, :] - ref_grid[None, None, :, :, :]
        distances = torch.sqrt(torch.sum(offsets ** 2, dim=-1))
        distances = distances.reshape(1, x_coords, y_coords, ref_size * ref_size)

        return distances.repeat(batch_size, 1, 1, 1).contiguous()

    def forward(self, spatial_input, function_input, time_input=None):
        """Run the model.

        Args:
            spatial_input: Per-point coordinates. Its leading dimension is
                used as the batch size; with ``use_unified_encoding`` false it
                is fed to the encoder directly.
            function_input: Per-point function values concatenated to the
                position encoding along the last dimension, or ``None``.
            time_input: Optional time input; when given, a time embedding is
                added to the encoded features.

        Returns:
            Tensor of shape ``(batch, grid_height * grid_width, channels)``
            where ``channels`` is the width produced by the last attention
            block.
        """
        if self.use_unified_encoding:
            batch = spatial_input.shape[0]
            # Build the encoding on the input's device so the model is not
            # tied to the default CUDA device.
            position_encoding = self.generate_position_grid(
                batch, device=spatial_input.device
            ).reshape(
                batch,
                self.grid_height * self.grid_width,
                self.reference_grid_size * self.reference_grid_size
            )
        else:
            position_encoding = spatial_input

        if function_input is not None:
            input_features = torch.cat((position_encoding, function_input), -1)
            features = self.input_encoder(input_features)
        else:
            features = self.input_encoder(position_encoding)
            features = features + self.latent_embedding[None, None, :]

        if time_input is not None:
            # NOTE(review): ``timestep_embedding`` is not imported or defined
            # in the visible part of this file; it must be available at module
            # scope for this branch to run. ``self.time_encoder`` exists only
            # when the model was built with ``enable_time_input=True``.
            time_embedding = timestep_embedding(time_input, self.embedding_dim).repeat(
                1, position_encoding.shape[1], 1
            )
            time_embedding = self.time_encoder(time_embedding)
            features = features + time_embedding

        for attention_block in self.attention_blocks:
            features = attention_block(features)

        skip_features = features
        batch_size, num_points, channels = features.shape

        # (batch, points, channels) -> (batch, channels, height, width) for
        # the Radon blocks.
        features = features.reshape(
            batch_size, self.grid_height, self.grid_width, channels
        ).permute(0, 3, 1, 2).contiguous()

        for radon_transform in self.radon_transforms:
            features = radon_transform(features)

        # Back to (batch, points, channels).
        features = features.permute(0, 2, 3, 1).contiguous().reshape(
            batch_size, -1, channels
        ).contiguous()

        features = features + skip_features

        return features